{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "unable to import 'smart_open.gcs', disabling that module\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from utils import *\n",
    "from data import *\n",
    "from models import *\n",
    "\n",
    "# Debug aid (disabled): uncomment the next line to turn on autograd anomaly detection.\n",
    "# torch.autograd.set_detect_anomaly(True)\n",
    "# Limit PyTorch to 4 threads for this session.\n",
    "torch.set_num_threads(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ********* MODIFY HERE *********\n",
    "# User-editable settings read by the cells below.\n",
    "args = dict(\n",
    "    # checkpoint path prefix; '.json' is appended to locate its config file\n",
    "    model_read_ckpt='./ckpts/model_to_read',\n",
    "    # language model name\n",
    "    lm_emb_path='albert-xxlarge-v1',\n",
    "    # original GloVe embeddings\n",
    "    pretrained_wv='./wv/glove.6B.100d.txt',\n",
    "    # number of rows allocated for the embedding table (GloVe contains 400,000 words)\n",
    "    vocab_size=400100,\n",
    "    # 'cpu', a CUDA device, or None to pick a device automatically\n",
    "    device='cpu',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve args['device']: auto-select when it is None, otherwise honour the given value.\n",
    "if args['device'] is None:\n",
    "    # No device requested: use a GPU when CUDA is available, else fall back to the CPU.\n",
    "    if not torch.cuda.is_available():\n",
    "        args['device'] = \"cpu\"\n",
    "    else:\n",
    "        gpu_idx, gpu_mem = set_max_available_gpu()\n",
    "        args['device'] = f\"cuda:{gpu_idx}\"\n",
    "elif args['device'] != 'cpu':\n",
    "    # An explicit non-CPU device was requested: make it the current CUDA device.\n",
    "    torch.cuda.set_device(args['device'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the config file saved next to the trained checkpoint.\n",
    "# JSON text is UTF-8, so the encoding is pinned rather than left to the\n",
    "# platform's locale default.\n",
    "with open(args['model_read_ckpt'] + '.json', 'r', encoding='utf-8') as f:\n",
    "    config = Config(**json.load(f))\n",
    "\n",
    "# The overrides below do not need the file handle, so they run after it is closed.\n",
    "# Language model used to dynamically calculate the contextualized word embeddings.\n",
    "config.lm_emb_path = args['lm_emb_path']\n",
    "# Device chosen in the settings / device-selection cells.\n",
    "config.device = args['device']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jue_wang/anaconda3/lib/python3.7/site-packages/torch/nn/modules/rnn.py:50: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
      "  \"num_layers={}\".format(dropout, num_layers))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "albert-xxlarge-v1 is not file, try load as bert model.\n",
      "albert-xxlarge-v1 loaded successfully.\n",
      "Note it only supports default options now, i.e.: \n",
      "  layers='-1,-2,-3,-4', use_scalar_mix=True, pooling_operation=\"mean\"\n"
     ]
    }
   ],
   "source": [
    "# Build the model from the loaded config, then load the checkpoint at args['model_read_ckpt'].\n",
    "# NOTE(review): this cell's stderr reports dropout=0.5; confirm that load_ckpt or\n",
    "# predict_step puts the model in evaluation mode before it is used for inference.\n",
    "model = JointModel(config)\n",
    "model.load_ckpt(args['model_read_ckpt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "400001it [00:16, 24582.56it/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the full GloVe embeddings into the model's token-embedding table.\n",
    "# * This is needed when the model was trained on a reduced version of the GloVe word vectors.\n",
    "# * You can comment out this block if the OOV problem is not very serious.\n",
    "# Step 1: grab the embedding weight and keep a reference to its current data.\n",
    "_w = model.token_embedding.token_embedding.weight \n",
    "_w_data = _w.data\n",
    "# Step 2: swap in a zero tensor with args['vocab_size'] rows, same dtype and device.\n",
    "_w.data = torch.zeros([args['vocab_size'], config.token_emb_dim], dtype=_w.dtype, device=_w.device)\n",
    "# Step 3: presumably fills the enlarged table from the GloVe file -- helper is not visible here.\n",
    "model.token_embedding.load_pretrained(args['pretrained_wv'], freeze=True)\n",
    "# Step 4: copy the previously held rows back over the first len(_w_data) rows.\n",
    "# NOTE(review): this writes through `_w`; confirm load_pretrained updates the existing\n",
    "# weight in place rather than rebinding it, otherwise the copy lands on a stale tensor.\n",
    "# NOTE(review): assumes args['vocab_size'] >= len(_w_data) -- TODO confirm.\n",
    "_w.data[:len(_w_data)] = _w_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model on a custom, pre-tokenized sentence.\n",
    "sentence = ['My', 'name', 'is', 'Jackson', ',', 'and', 'I', 'live', 'in', 'Berlin', '.']\n",
    "# 'tokens' holds a list of sentences; here the batch contains a single one.\n",
    "batch = {'tokens': [sentence]}\n",
    "rets = model.predict_step(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[(0, 1, 'PER'), (3, 4, 'PER'), (6, 7, 'PER'), (9, 10, 'GPE')]]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the predicted entities, one list per input sentence.\n",
    "# Judging by the recorded output, each item appears to be (start, end_exclusive, label)\n",
    "# over token indices -- TODO confirm against predict_step.\n",
    "rets['entity_preds']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{(6, 7, 9, 10, 'PHYS')}]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the predicted relations, one set per input sentence.\n",
    "# Judging by the recorded output, each item appears to be\n",
    "# (start_1, end_1, start_2, end_2, label) over token indices -- TODO confirm against predict_step.\n",
    "rets['relation_preds']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
